import argparse
import os
import pickle
import random

import gym
import numpy as np
from skimage.transform import resize
from IPython import embed

import common_args
from envs import darkroom_env, bandit_env_new
from ctrls.ctrl_bandit_new import (
    ThompsonSamplingPolicy, 
    UnifPolicy, 
    UCBPolicy,
    LinUCBPolicy,
    LinUCBPolicy_wt,
    OptDesPolicy,
    UnifPolicy_wt,
    ThompsonSamplingPolicy_wt,
)
from evals import eval_bandit_new
from utils import (
    build_bandit_data_filename,
    build_linear_bandit_data_filename,
    build_darkroom_data_filename,
    build_miniworld_data_filename,
)

import ipdb
from tqdm import tqdm
import joblib
from joblib import Parallel, delayed

def rollin_bandit(env, cov, orig=False):
    """Roll in ``env.H_context`` steps of a random stochastic policy on a bandit.

    The behaviour distribution over arms is a convex mixture of a
    Dirichlet(1, ..., 1) sample and a one-hot vector on a uniformly chosen arm.
    The mixing weight is re-drawn from {0.0, 0.1, ..., 1.0} on every call, so
    the ``cov`` argument is overwritten and has no effect; ``orig`` is unused.

    Returns:
        ``(xs, us, xps, rs)`` as numpy arrays: the constant state ``[1]`` for
        every step, the one-hot actions taken, and the next states and rewards
        returned by ``env.transit``.
    """
    horizon = env.H_context
    _ = env.opt_a_index  # read but not used below
    n_arms = env.dim
    arm_ids = np.arange(n_arms)

    # Mixing weight between the Dirichlet draw and the one-hot draw (replaces ``cov``).
    cov = np.random.choice([0.0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0])
    dirichlet_probs = np.random.dirichlet(np.ones(n_arms))
    one_hot = np.zeros(n_arms)
    one_hot[np.random.choice(arm_ids)] = 1.0
    probs = (1 - cov) * dirichlet_probs + cov * one_hot

    steps = []
    for _step in range(horizon):
        state = np.array([1])
        action = np.zeros(n_arms)
        action[np.random.choice(arm_ids, p=probs)] = 1.0
        next_state, reward = env.transit(state, action)
        steps.append((state, action, next_state, reward))

    xs = np.array([step[0] for step in steps])
    us = np.array([step[1] for step in steps])
    xps = np.array([step[2] for step in steps])
    rs = np.array([step[3] for step in steps])
    return xs, us, xps, rs


def rollin_linear_bandit_vec(envs):
    """Collect one batched history per env with a Thompson-sampling behaviour policy.

    The policy is built from ``envs[0]`` (``std=envs[0].var``, prior_mean 0.0,
    prior_var 1.0, no warm start) with ``batch_size=len(envs)`` and deployed for
    ``envs[0].H_context`` steps via ``eval_bandit_new.deploy_online_vec``.

    Returns:
        ``(context_states, context_actions, context_next_states, context_rewards)``
        read from the deploy ``meta`` dict; rewards keep only index 0 of their
        third axis.
    """
    H = envs[0].H_context

    thmp = ThompsonSamplingPolicy(
        envs[0],
        std=envs[0].var,
        sample=True,
        prior_mean=0.0,
        prior_var=1.0,
        warm_start=False,
        batch_size=len(envs),
    )

    # The vectorised env comes from ``bandit_env_new``; this module never imports
    # a ``bandit_env`` module, so referencing it raised NameError.
    vec_env = bandit_env_new.BanditEnvVec(envs)
    _, meta = eval_bandit_new.deploy_online_vec(vec_env, thmp, H, include_meta=True)
    return (
        meta['context_states'],
        meta['context_actions'],
        meta['context_next_states'],
        meta['context_rewards'][:, :, 0],
    )


def rollin_linear_bandit_vec_unif(envs):
    """Collect one batched history per env with a uniform behaviour policy.

    The policy is ``UnifPolicy(envs[0], batch_size=len(envs))``, deployed for
    ``envs[0].H_context`` steps via ``eval_bandit_new.deploy_online_vec``.

    Returns:
        ``(context_states, context_actions, context_next_states, context_rewards)``
        read from the deploy ``meta`` dict; rewards keep only index 0 of their
        third axis.
    """
    H = envs[0].H_context

    print("batch_size:", len(envs))
    unif = UnifPolicy(envs[0], batch_size=len(envs))

    # The vectorised env comes from ``bandit_env_new``; this module never imports
    # a ``bandit_env`` module, so referencing it raised NameError.
    vec_env = bandit_env_new.BanditEnvVec(envs)
    _, meta = eval_bandit_new.deploy_online_vec(vec_env, unif, H, include_meta=True)
    return (
        meta['context_states'],
        meta['context_actions'],
        meta['context_next_states'],
        meta['context_rewards'][:, :, 0],
    )


def _make_thompson_controller(policy_cls, envs):
    """Build a Thompson-sampling controller for the batch (prior_mean 0.0, prior_var 1.0)."""
    return policy_cls(
        envs[0],
        std=envs[0].var,
        sample=True,
        prior_mean=0.0,
        prior_var=1.0,
        warm_start=False,
        batch_size=len(envs),
    )


def _replicate_opt_design(design_policy, envs):
    """Compute the optimal design for ``envs[0].arms`` and copy it into one row per env."""
    print("Calculating optimal design")
    optimal_design = np.zeros((len(envs), envs[0].dim))
    # Arm features are assumed identical across envs, so one computation serves
    # the whole batch.
    # NOTE(review): envs built with per-env arms (e.g. the "new_arms_*" option of
    # generate_linear_bandit_histories_wt) would all receive env 0's design — confirm.
    optimal_design[:] = design_policy.calculate_opt_design(envs[0].arms)
    return optimal_design


def _build_data_controllers(controller_type, envs):
    """Instantiate the behaviour policy (or policies) named by ``controller_type``.

    Returns:
        ``(controller, controller2, optimal_design)``; ``controller2`` and
        ``optimal_design`` are None for types that do not use them.

    Raises:
        ValueError: if ``controller_type`` is not recognised.
    """
    env0 = envs[0]
    batch_size = len(envs)
    controller2 = None
    optimal_design = None

    def announce(message):
        print(message)
        print("batch_size:", batch_size)

    if controller_type == "unif":
        announce("Collecting data through Uniform")
        controller = UnifPolicy(env0, batch_size=batch_size)
    elif controller_type == "ucb":
        announce("Collecting data through UCB")
        controller = UCBPolicy(env0, batch_size=batch_size)
    elif controller_type == "thompson":
        announce("Collecting data through Thompson Sampling")
        controller = _make_thompson_controller(ThompsonSamplingPolicy, envs)
    elif controller_type == "linucb":
        announce("Collecting data through LinUCB")
        controller = LinUCBPolicy(env0, const=1.0, batch_size=batch_size)
    elif controller_type in ("linucb_wt", "linucb_original", "linucb_pred_reward", "linucb_pred_reward_opt_a"):
        announce("Collecting data through LinUCB")
        controller = LinUCBPolicy_wt(env0, const=1.0, batch_size=batch_size)
    elif controller_type in ("unif_wt", "unif_original", "unif_pred_reward", "unif_pred_reward_opt_a"):
        announce("Collecting data through Unif")
        controller = UnifPolicy_wt(env0, batch_size=batch_size)
    elif controller_type in ("TS_wt", "TS_original", "TS_pred_reward", "TS_pred_reward_opt_a"):
        announce("Collecting data through Thompson Sampling")
        controller = _make_thompson_controller(ThompsonSamplingPolicy_wt, envs)
    elif controller_type in ("optdes_wt", "optdes_original", "optdes_pred_reward", "optdes_pred_reward_opt_a"):
        announce("Collecting data through Optimal Design")
        controller = OptDesPolicy(env0, const=1.0, batch_size=batch_size)
        optimal_design = _replicate_opt_design(controller, envs)
    elif controller_type == "linucb_optdes":
        announce("Collecting data through LinUCB and Optimal design")
        controller = LinUCBPolicy(env0, const=1.0, batch_size=batch_size)
        controller2 = OptDesPolicy(env0, const=1.0, batch_size=batch_size)
        optimal_design = _replicate_opt_design(controller2, envs)
    elif controller_type == "TS_optdes":
        announce("Collecting data through TS and Optimal design")
        controller = _make_thompson_controller(ThompsonSamplingPolicy, envs)
        controller2 = OptDesPolicy(env0, const=1.0, batch_size=batch_size)
        optimal_design = _replicate_opt_design(controller2, envs)
    elif controller_type == "TS_LinUCB":
        announce("Collecting data through TS and LinUCB")
        controller = _make_thompson_controller(ThompsonSamplingPolicy, envs)
        controller2 = LinUCBPolicy(env0, const=1.0, batch_size=batch_size)
    else:
        raise ValueError(f"Unknown controller_type: {controller_type!r}")

    return controller, controller2, optimal_design


def _deploy_for_train_type(vec_env, controller, controller2, H, controller_type, train_type, optimal_design):
    """Run the ``eval_bandit_new`` deploy routine selected by ``train_type``; return its meta dict.

    Raises:
        ValueError: for an unknown ``train_type``, or ``"lookahead_mix"`` without
            a second controller.
    """
    if train_type == "opt":
        _, meta = eval_bandit_new.deploy_online_vec(vec_env, controller, H, include_meta=True)
    elif train_type == "lookahead":
        _, meta = eval_bandit_new.deploy_online_vec_lookahead(vec_env, controller, H, include_meta=True)
    elif train_type == "lookahead_wt":
        _, meta = eval_bandit_new.deploy_online_vec_lookahead_wt(vec_env, controller, H, include_meta=True)
    elif train_type == "lookahead_mix":
        if controller2 is None:
            raise ValueError(
                f"train_type 'lookahead_mix' needs a two-policy controller_type, got {controller_type!r}"
            )
        epsilon = 0.3
        if controller_type == "TS_LinUCB":
            _, meta = eval_bandit_new.deploy_online_vec_lookahead_mix(
                vec_env, controller, controller2, H, epsilon=epsilon, include_meta=True)
        else:
            _, meta = eval_bandit_new.deploy_online_vec_lookahead_mix(
                vec_env, controller, controller2, H, epsilon=epsilon,
                optimal_design=optimal_design, include_meta=True)
    elif train_type in ("train_original", "train_original_emp_opt"):
        collect_data_config = "true_opt" if train_type == "train_original" else "emp_opt"
        if controller_type == "optdes_original":
            # NOTE(review): collect_data_config is not passed on this path, so the
            # "emp_opt" setting cannot take effect here — confirm this is intended.
            _, meta = eval_bandit_new.deploy_online_vec_train_original(
                vec_env, controller, H, include_meta=True, optimal_design=optimal_design)
        else:
            _, meta = eval_bandit_new.deploy_online_vec_train_original(
                vec_env, controller, H, collect_data_config, include_meta=True)
    elif train_type == "lookahead_pred_reward":
        if controller_type == "optdes_pred_reward":
            _, meta = eval_bandit_new.deploy_online_vec_lookahead_pred_reward(
                vec_env, controller, H, include_meta=True, optimal_design=optimal_design)
        else:
            _, meta = eval_bandit_new.deploy_online_vec_lookahead_pred_reward(
                vec_env, controller, H, include_meta=True)
    elif train_type == "lookahead_pred_reward_opt_a":
        _, meta = eval_bandit_new.deploy_online_vec_lookahead_pred_reward_opt_a(
            vec_env, controller, H, include_meta=True)
    else:
        raise ValueError(f"Unknown train_type: {train_type!r}")
    return meta


def rollin_linear_bandit_vec_custom(envs, controller_type = "thompson", train_type = "opt"):
    """Collect one batched history per env using a configurable behaviour policy.

    Args:
        envs: list of bandit envs. ``envs[0]`` is used to construct the policies
            and supplies ``H_context`` (and ``arms`` for optimal-design types).
        controller_type: behaviour-policy name (see ``_build_data_controllers``).
        train_type: selects the ``eval_bandit_new`` deploy routine and the
            length of the returned tuple.

    Returns:
        ``(context_states, context_actions, context_next_states, context_rewards,
        context_opt_actions)`` read from the deploy ``meta`` dict (rewards keep
        index 0 of their third axis), extended with:
        - ``context_sum_rewards`` for ``"lookahead_wt"``;
        - ``context_sum_rewards, context_pred_rewards`` for ``"train_original"``,
          ``"train_original_emp_opt"`` and ``"lookahead_pred_reward"``;
        - ``context_sum_rewards, context_pred_rewards, context_pred_opt_a`` for
          ``"lookahead_pred_reward_opt_a"``.

    Raises:
        ValueError: for an unknown ``controller_type`` or ``train_type``, or when
            ``"lookahead_mix"`` is requested with a single-policy controller type.
    """
    H = envs[0].H_context
    controller, controller2, optimal_design = _build_data_controllers(controller_type, envs)

    vec_env = bandit_env_new.BanditEnvVec(envs)
    meta = _deploy_for_train_type(
        vec_env, controller, controller2, H, controller_type, train_type, optimal_design)

    base = (
        meta['context_states'],
        meta['context_actions'],
        meta['context_next_states'],
        meta['context_rewards'][:, :, 0],
        meta['context_opt_actions'],
    )

    if train_type == "lookahead_wt":
        # Empirical means of all actions in the batch.
        return base + (meta['context_sum_rewards'],)

    if train_type in ("train_original", "train_original_emp_opt", "lookahead_pred_reward"):
        # Empirical means and predicted rewards of all actions in the batch.
        return base + (meta['context_sum_rewards'], meta['context_pred_rewards'])

    if train_type == "lookahead_pred_reward_opt_a":
        # As above, plus the predicted optimal action of every env in the batch.
        return base + (
            meta['context_sum_rewards'],
            meta['context_pred_rewards'],
            meta['context_pred_opt_a'],
        )

    return base





def rand_pos_and_dir(env):
    """Draw a random position and heading for an agent inside ``env``.

    The position is a 3-vector whose components are uniform in ``[0, env.size)``,
    except component 1, which is forced to 0.0 (presumably the vertical axis —
    confirm against the env). The heading is uniform in ``[0, 2*pi)``.

    Returns:
        ``(position, heading)``: a float array of length 3 and a float.
    """
    position = np.random.uniform(0, env.size, size=3)
    heading = np.random.uniform(0, 2 * np.pi)
    position[1] = 0.0
    return position, heading





def generate_bandit_histories_from_envs(envs, n_hists, n_samples, cov):
    """Build trajectory dicts from ``rollin_bandit`` roll-ins on each env.

    For every env, ``n_hists`` histories are rolled in; each history is then
    emitted ``n_samples`` times as a separate dict sharing the same context
    arrays (only ``query_state`` is a fresh array per dict).

    Returns:
        list of dicts with keys ``query_state``, ``optimal_action``,
        ``context_states``, ``context_actions``, ``context_next_states``,
        ``context_rewards`` and ``means``.
    """
    trajectories = []
    for environment in envs:
        for _hist in range(n_hists):
            states, actions, next_states, rewards = rollin_bandit(environment, cov=cov)
            trajectories.extend(
                {
                    'query_state': np.array([1]),
                    'optimal_action': environment.opt_a,
                    'context_states': states,
                    'context_actions': actions,
                    'context_next_states': next_states,
                    'context_rewards': rewards,
                    'means': environment.means,
                }
                for _sample in range(n_samples)
            )
    return trajectories



def generate_bandit_histories(n_envs, dim, horizon, var, **kwargs):
    """Sample ``n_envs`` bandit envs and build trajectories from them.

    ``kwargs`` is forwarded to ``generate_bandit_histories_from_envs`` and must
    therefore supply ``n_hists``, ``n_samples`` and ``cov``.
    """
    # This module imports ``bandit_env_new`` (not ``bandit_env``), so the former
    # ``bandit_env.sample`` reference always raised NameError.
    # NOTE(review): assumes bandit_env_new exposes ``sample(dim, horizon, var)``
    # alongside the ``sample_*`` constructors used elsewhere in this file — confirm.
    envs = [bandit_env_new.sample(dim, horizon, var) for _ in range(n_envs)]
    return generate_bandit_histories_from_envs(envs, **kwargs)


# Data types whose histories are produced in batch by rollin_linear_bandit_vec_custom.
_LINEAR_VEC_DATA_TYPES = ('thompson', 'unif', 'ucb', 'linucb', 'linucb_optdes', 'TS_optdes', 'TS_LinUCB')


def _kwarg_or_module_global(kwargs, name):
    """Return ``kwargs[name]``, falling back to the module-level global ``name``.

    Raises:
        TypeError: if neither is defined.
    """
    if name in kwargs:
        return kwargs[name]
    module_globals = globals()
    if name in module_globals:
        return module_globals[name]
    raise TypeError(
        f"generate_linear_bandit_histories() needs {name!r}; pass it as a keyword argument"
    )


def _collect_linear_vec_histories(envs, data_type, train_type, n_hists):
    """Run ``n_hists`` batched roll-ins and stack each field along a new axis 1.

    The stacked arrays are later indexed as ``[env_index, hist_index]``.
    """
    states, actions, next_states, rewards, opt_actions = [], [], [], [], []
    for _ in range(n_hists):
        s, a, ns, r, oa = rollin_linear_bandit_vec_custom(envs, data_type, train_type)
        states.append(s)
        actions.append(a)
        next_states.append(ns)
        rewards.append(r)
        opt_actions.append(oa)
    return tuple(
        np.stack(field, axis=1)
        for field in (states, actions, next_states, rewards, opt_actions)
    )


def generate_linear_bandit_histories(n_envs, dim, lin_d, horizon, var, exclude_arm = False, include_arm = -1, train_type = "opt", **kwargs):
    """Sample linear bandits with shared arm features and build trajectory dicts.

    Arm features are a fixed ``(dim, lin_d)`` Gaussian draw (seed 1234) scaled
    by ``1/sqrt(lin_d)``, shared by every env.

    Args:
        n_envs: number of envs to sample.
        dim: number of arms.
        lin_d: feature dimension of each arm.
        horizon, var: forwarded to the ``bandit_env_new.sample_linear*`` constructors.
        exclude_arm: when ``== False`` envs come from ``sample_linear`` /
            ``sample_linear_include``; otherwise from ``sample_linear_exclude``.
        include_arm: -1 disables inclusion mode; any other value enables it
            (see the NOTE in the body about the value itself).
        train_type: forwarded to ``rollin_linear_bandit_vec_custom``.
        **kwargs: ``data_type``, ``n_hists`` and ``n_samples``. Each falls back
            to a module-level global of the same name when not given here.

    Returns:
        list of dicts; each (env, history) pair contributes ``n_samples`` dicts
        with the same content.

    Raises:
        ValueError: if ``data_type`` is neither ``'uniform'`` nor one of
            ``_LINEAR_VEC_DATA_TYPES``.
    """
    data_type = _kwarg_or_module_global(kwargs, 'data_type')
    n_hists = _kwarg_or_module_global(kwargs, 'n_hists')
    n_samples = _kwarg_or_module_global(kwargs, 'n_samples')

    if data_type != 'uniform' and data_type not in _LINEAR_VEC_DATA_TYPES:
        raise ValueError("Invalid data type")

    # Fixed features for the arms of all linear bandits.
    rng = np.random.RandomState(seed=1234)
    arms = rng.normal(size=(dim, lin_d)) / np.sqrt(lin_d)

    if exclude_arm == False:
        if include_arm == -1:
            print("No exclusion/inclusion of arms")
            envs = [bandit_env_new.sample_linear(arms, horizon, var)
                    for _ in range(n_envs)]
        else:
            print("Inclusion of arms")
            # NOTE(review): the caller's include_arm value is overridden here, so
            # arm 1 is always the "included" arm — confirm this is intended.
            include_arm = 1
            # 80% of envs get one of the other candidate arms, the rest get include_arm.
            n_common = int(.8 * n_envs)
            common_choices = [arm for arm in [1, 2, 4] if arm != include_arm]
            common_draws = np.random.choice(common_choices, size=n_common)
            rare_draws = np.random.choice([include_arm], size=(n_envs - n_common))
            arm_list = np.concatenate([common_draws, rare_draws])
            np.random.shuffle(arm_list)
            arm_list = list(arm_list)

            envs = [bandit_env_new.sample_linear_include(arms, horizon, var, include_arm=arm_list[i])
                    for i in range(n_envs)]
    else:
        print("Exclusion of arms")
        envs = [bandit_env_new.sample_linear_exclude(arms, horizon, var)
                for _ in range(n_envs)]

    print("Generating histories...")
    if data_type in _LINEAR_VEC_DATA_TYPES:
        (context_states_all, context_actions_all, context_next_states_all,
         context_rewards_all, context_opt_actions_all) = _collect_linear_vec_histories(
            envs, data_type, train_type, n_hists)

    trajs = []
    for i, env in enumerate(envs):
        print('Generating linear bandit histories for env {}/{}'.format(i+1, n_envs))
        for j in range(n_hists):
            if data_type == 'uniform':
                # rollin_bandit returns only (states, actions, next_states, rewards).
                context_states, context_actions, context_next_states, context_rewards = rollin_bandit(env, cov=0.0)
                # NOTE(review): this path has no per-step optimal actions, so the
                # field is stored as None — confirm downstream consumers accept it.
                context_opt_actions = None
            else:
                context_states = context_states_all[i, j]
                context_actions = context_actions_all[i, j]
                context_next_states = context_next_states_all[i, j]
                context_rewards = context_rewards_all[i, j]
                context_opt_actions = context_opt_actions_all[i, j]

            for _ in range(n_samples):
                trajs.append({
                    'query_state': np.array([1]),
                    'optimal_action': env.opt_a,
                    'context_states': context_states,
                    'context_actions': context_actions,
                    'context_next_states': context_next_states,
                    'context_rewards': context_rewards,
                    'context_opt_actions': context_opt_actions,
                    'means': env.means,
                    'arms': arms,
                    'theta': env.theta,
                    'var': env.var,
                })
    return trajs





def generate_linear_bandit_histories_wt(n_envs, dim, lin_d, horizon, var, exclude_arm = False, include_arm = -1, train_type = "opt", pred_reward_type = "linear", **kwargs):
    # generate fixed features for arms of all linear bandits
    
    rng = np.random.RandomState(seed=1234)
    # rng = np.random.RandomState(seed=1541)
    arms = rng.normal(size=(dim, lin_d)) / np.sqrt(lin_d)

    if exclude_arm == False:
        if include_arm == -1:
            print("No exclusion/inclusion of arms")

            if pred_reward_type == "linear":
                envs = [bandit_env_new.sample_linear(arms, horizon, var)
                        for _ in range(n_envs)]
            elif pred_reward_type == "censored":
                envs = [bandit_env_new.sample_linear_censored(arms, horizon, var)
                        for _ in range(n_envs)]
            elif pred_reward_type == "non_linear":
                envs = [bandit_env_new.sample_linear_nlm(arms, horizon, var)
                        for _ in range(n_envs)]
            elif pred_reward_type == "bilinear":
                
                # rng = np.random.RandomState(seed=1234)
                # arms_left = rng.normal(size=(dim, lin_d)) / np.sqrt(lin_d)
                # arms_right = rng.normal(size=(dim, lin_d)) / np.sqrt(lin_d)
                envs = [bandit_env_new.sample_bilinear(arms, horizon, var, rank = kwargs['rank'])
                        for _ in range(n_envs)]
            elif pred_reward_type == "latent":
                envs = [bandit_env_new.sample_latent(arms, horizon, var, rank = kwargs['rank'])
                        for _ in range(n_envs)]
            elif pred_reward_type == "new_arms_linear":
                
                num_new_arms = kwargs['new_arms']
                print("New arms", num_new_arms)

                # First create dim-1 arms same, then add 1 or 5 new arm randomly to each new env
                rng = np.random.RandomState(seed=1234)
                # arms = rng.normal(size=(dim-1, lin_d)) / np.sqrt(lin_d)
                arms = rng.normal(size=(dim-num_new_arms, lin_d)) / np.sqrt(lin_d)
                envs = []
                arms_total = np.zeros((n_envs, dim, lin_d))
                for i in range(n_envs):
                    rng = np.random.RandomState()

                    # Then add the 1 or 5 new arms randomly
                    # arm_new = rng.normal(size=(1, lin_d)) / np.sqrt(lin_d)
                    arm_new = rng.normal(size=(num_new_arms, lin_d)) / np.sqrt(lin_d)


                    arms_total[i] = np.concatenate([arms, arm_new], axis = 0)
                    # print("New arm added", arm_new)
                    # print("All arms", arms_total[i], np.shape(arms_total[i]))
                    envs.append(bandit_env_new.sample_linear(arms_total[i], horizon, var))
            elif pred_reward_type == "new_arms_non_linear":
                
                num_new_arms = kwargs['new_arms']
                print("New arms", num_new_arms)
                
                # First create dim-1 arms same, then add 1 or 5 new arm randomly to each new env
                rng = np.random.RandomState(seed=1234)
                # arms = rng.normal(size=(dim-1, lin_d)) / np.sqrt(lin_d)
                arms = rng.normal(size=(dim-num_new_arms, lin_d)) / np.sqrt(lin_d)
                envs = []
                arms_total = np.zeros((n_envs, dim, lin_d))
                for i in range(n_envs):
                    rng = np.random.RandomState()

                    # Then add the 1 or 5 new arms randomly
                    # arm_new = rng.normal(size=(1, lin_d)) / np.sqrt(lin_d)
                    arm_new = rng.normal(size=(num_new_arms, lin_d)) / np.sqrt(lin_d)


                    arms_total[i] = np.concatenate([arms, arm_new], axis = 0)
                    # print("New arm added", arm_new)
                    # print("All arms", arms_total[i], np.shape(arms_total[i]))
                    envs.append(bandit_env_new.sample_linear_nlm(arms_total[i], horizon, var))
                
                
                
        else:
            print("Inclusion of arms")
            include_arm = 1
            # print(n_envs)
            a = [i for i in [1,2,4] if i != include_arm]
            a1_ = np.random.choice(a, size = int(.8*n_envs))
            
            a = [include_arm]
            a2_ = np.random.choice(a, size = (n_envs - int(.8*n_envs)))
            arm_list = np.concatenate([a1_,a2_],)
            np.random.shuffle(arm_list)
            arm_list = list(arm_list)

            # print("arm list", arm_list)

            envs = [bandit_env_new.sample_linear_include(arms, horizon, var, include_arm = arm_list[i])
                    for i in range(n_envs)]
    else:
        print("Exclusion of arms")
        envs = [bandit_env_new.sample_linear_exclude(arms, horizon, var)
                for _ in range(n_envs)]

    ###### Uncomment when running a new environment #####
    # data_type = kwargs['data_type']
    # n_hists = kwargs['n_hists']
    # n_samples = kwargs['n_samples']

    print("Generating histories...")
    if data_type=='thompson':
        context_states_all, context_actions_all, context_next_states_all, context_rewards_all, context_opt_actions_all, context_sum_rewards_all = [], [], [], [], [], []
        for j in range(n_hists):
            context_states, context_actions, context_next_states, context_rewards, context_opt_actions, context_sum_rewards = rollin_linear_bandit_vec_custom(envs, data_type, train_type)

            context_states_all.append(context_states)
            context_actions_all.append(context_actions)
            context_next_states_all.append(context_next_states)
            context_rewards_all.append(context_rewards)
            context_opt_actions_all.append(context_opt_actions)
            context_sum_rewards_all.append(context_sum_rewards)

        context_states_all = np.stack(context_states_all, axis=1)
        context_actions_all = np.stack(context_actions_all, axis=1)
        context_next_states_all = np.stack(context_next_states_all, axis=1)
        context_rewards_all = np.stack(context_rewards_all, axis=1)
        context_opt_actions_all = np.stack(context_opt_actions_all, axis=1)
        context_sum_rewards_all = np.stack(context_sum_rewards_all, axis=1)

        

    elif data_type=='unif' or data_type=='ucb':
        context_states_all, context_actions_all, context_next_states_all, context_rewards_all, context_opt_actions_all, context_sum_rewards_all = [], [], [], [], [], []
        for j in range(n_hists):
            context_states, context_actions, context_next_states, context_rewards, context_opt_actions, context_sum_rewards = rollin_linear_bandit_vec_custom(envs, data_type, train_type)

            context_states_all.append(context_states)
            context_actions_all.append(context_actions)
            context_next_states_all.append(context_next_states)
            context_rewards_all.append(context_rewards)
            context_opt_actions_all.append(context_opt_actions)
            context_sum_rewards_all.append(context_sum_rewards)

        context_states_all = np.stack(context_states_all, axis=1)
        context_actions_all = np.stack(context_actions_all, axis=1)
        context_next_states_all = np.stack(context_next_states_all, axis=1)
        context_rewards_all = np.stack(context_rewards_all, axis=1)
        context_opt_actions_all = np.stack(context_opt_actions_all, axis=1)
        context_sum_rewards_all = np.stack(context_sum_rewards_all, axis=1)

    elif data_type=='linucb':
        context_states_all, context_actions_all, context_next_states_all, context_rewards_all, context_opt_actions_all, context_sum_rewards_all = [], [], [], [], [], []
        for j in range(n_hists):
            context_states, context_actions, context_next_states, context_rewards, context_opt_actions, context_sum_rewards = rollin_linear_bandit_vec_custom(envs, data_type, train_type)

            context_states_all.append(context_states)
            context_actions_all.append(context_actions)
            context_next_states_all.append(context_next_states)
            context_rewards_all.append(context_rewards)
            context_opt_actions_all.append(context_opt_actions)
            context_sum_rewards_all.append(context_sum_rewards)

        context_states_all = np.stack(context_states_all, axis=1)
        context_actions_all = np.stack(context_actions_all, axis=1)
        context_next_states_all = np.stack(context_next_states_all, axis=1)
        context_rewards_all = np.stack(context_rewards_all, axis=1)
        context_opt_actions_all = np.stack(context_opt_actions_all, axis=1)
        context_sum_rewards_all = np.stack(context_sum_rewards_all, axis=1)
    
    elif data_type=='linucb_wt' or data_type=='unif_wt' or data_type=='TS_wt' or data_type=='optdes_wt':
        context_states_all, context_actions_all, context_next_states_all, context_rewards_all, context_opt_actions_all, context_sum_rewards_all = [], [], [], [], [], []
        for j in range(n_hists):
            context_states, context_actions, context_next_states, context_rewards, context_opt_actions, context_sum_rewards = rollin_linear_bandit_vec_custom(envs, data_type, train_type)

            context_states_all.append(context_states)
            context_actions_all.append(context_actions)
            context_next_states_all.append(context_next_states)
            context_rewards_all.append(context_rewards)
            context_opt_actions_all.append(context_opt_actions)
            context_sum_rewards_all.append(context_sum_rewards)

        context_states_all = np.stack(context_states_all, axis=1)
        context_actions_all = np.stack(context_actions_all, axis=1)
        context_next_states_all = np.stack(context_next_states_all, axis=1)
        context_rewards_all = np.stack(context_rewards_all, axis=1)
        context_opt_actions_all = np.stack(context_opt_actions_all, axis=1)
        context_sum_rewards_all = np.stack(context_sum_rewards_all, axis=1)

    elif data_type=='linucb_original' or data_type=='unif_original' or data_type=='TS_original' or data_type=='optdes_original':
        context_states_all, context_actions_all, context_next_states_all, context_rewards_all, context_opt_actions_all, context_sum_rewards_all, context_pred_rewards_all = [], [], [], [], [], [], []
        for j in range(n_hists):
            
            context_states, context_actions, context_next_states, context_rewards, context_opt_actions, context_sum_rewards, context_pred_rewards = rollin_linear_bandit_vec_custom(envs, data_type, train_type)
            
            context_states_all.append(context_states)
            context_actions_all.append(context_actions)
            context_next_states_all.append(context_next_states)
            context_rewards_all.append(context_rewards)
            context_opt_actions_all.append(context_opt_actions)
            context_sum_rewards_all.append(context_sum_rewards)
            context_pred_rewards_all.append(context_pred_rewards)

        context_states_all = np.stack(context_states_all, axis=1)
        context_actions_all = np.stack(context_actions_all, axis=1)
        context_next_states_all = np.stack(context_next_states_all, axis=1)
        context_rewards_all = np.stack(context_rewards_all, axis=1)
        context_opt_actions_all = np.stack(context_opt_actions_all, axis=1)
        context_sum_rewards_all = np.stack(context_sum_rewards_all, axis=1)
        context_pred_rewards_all = np.stack(context_pred_rewards_all, axis=1)

    elif data_type=='linucb_pred_reward' or data_type=='unif_pred_reward' or data_type=='TS_pred_reward' or data_type=='optdes_pred_reward':
        context_states_all, context_actions_all, context_next_states_all, context_rewards_all, context_opt_actions_all, context_sum_rewards_all, context_pred_rewards_all = [], [], [], [], [], [], []
        
        for j in range(n_hists):
            context_states, context_actions, context_next_states, context_rewards, context_opt_actions, context_sum_rewards, context_pred_rewards = rollin_linear_bandit_vec_custom(envs, data_type, train_type)

            context_states_all.append(context_states)
            context_actions_all.append(context_actions)
            context_next_states_all.append(context_next_states)
            context_rewards_all.append(context_rewards)
            context_opt_actions_all.append(context_opt_actions)
            context_sum_rewards_all.append(context_sum_rewards)
            context_pred_rewards_all.append(context_pred_rewards)

        context_states_all = np.stack(context_states_all, axis=1)
        context_actions_all = np.stack(context_actions_all, axis=1)
        context_next_states_all = np.stack(context_next_states_all, axis=1)
        context_rewards_all = np.stack(context_rewards_all, axis=1)
        context_opt_actions_all = np.stack(context_opt_actions_all, axis=1)
        context_sum_rewards_all = np.stack(context_sum_rewards_all, axis=1)
        context_pred_rewards_all = np.stack(context_pred_rewards_all, axis=1)
    
    elif data_type=='linucb_pred_reward_opt_a' or data_type=='unif_pred_reward_opt_a' or data_type=='TS_pred_reward_opt_a' or data_type=='optdes_pred_reward_opt_a':
        context_states_all, context_actions_all, context_next_states_all, context_rewards_all, context_opt_actions_all, context_sum_rewards_all, context_pred_rewards_all, context_pred_opt_a_all = [], [], [], [], [], [], [], []
        for j in range(n_hists):
            context_states, context_actions, context_next_states, context_rewards, context_opt_actions, context_sum_rewards, context_pred_rewards, context_pred_opt_a = rollin_linear_bandit_vec_custom(envs, data_type, train_type)

            context_states_all.append(context_states)
            context_actions_all.append(context_actions)
            context_next_states_all.append(context_next_states)
            context_rewards_all.append(context_rewards)
            context_opt_actions_all.append(context_opt_actions)
            context_sum_rewards_all.append(context_sum_rewards)
            context_pred_rewards_all.append(context_pred_rewards)
            context_pred_opt_a_all.append(context_pred_opt_a)

        context_states_all = np.stack(context_states_all, axis=1)
        context_actions_all = np.stack(context_actions_all, axis=1)
        context_next_states_all = np.stack(context_next_states_all, axis=1)
        context_rewards_all = np.stack(context_rewards_all, axis=1)
        context_opt_actions_all = np.stack(context_opt_actions_all, axis=1)
        context_sum_rewards_all = np.stack(context_sum_rewards_all, axis=1)
        context_pred_rewards_all = np.stack(context_pred_rewards_all, axis=1)
        context_pred_opt_a_all = np.stack(context_pred_opt_a_all, axis=1)
    
    trajs = []
    for i, env in enumerate(envs):
        print('Generating linear bandit histories for env {}/{}'.format(i+1, n_envs))
        
        for j in range(n_hists):
            if data_type=='uniform':
                context_states, context_actions, context_next_states, context_rewards, context_opt_actions = rollin_bandit(env, cov=0.0)
                
            
            elif data_type=='thompson' or data_type=='unif' or data_type=='ucb' or data_type=='TS_wt' or data_type=='linucb_wt' or data_type=='unif_wt' or data_type=='optdes_wt':
                context_states = context_states_all[i, j]
                context_actions = context_actions_all[i, j]
                context_next_states = context_next_states_all[i, j]
                context_rewards = context_rewards_all[i, j]
                context_opt_actions = context_opt_actions_all[i, j]
                context_sum_rewards = context_sum_rewards_all[i, j]
                
                if pred_reward_type == "new_arms_linear" or pred_reward_type == "new_arms_non_linear":
                    arms = arms_total[i]

                for k in range(n_samples):
                    query_state = np.array([1])
                    optimal_action = env.opt_a
                    traj = {
                        'query_state': query_state,
                        'optimal_action': optimal_action,
                        'context_states': context_states,
                        'context_actions': context_actions,
                        'context_next_states': context_next_states,
                        'context_rewards': context_rewards,
                        'context_opt_actions': context_opt_actions,
                        'context_sum_rewards': context_sum_rewards,
                        'means': env.means,
                        'arms': arms,
                        'theta': env.theta,
                        'var': env.var,
                    }
                    trajs.append(traj)

            elif data_type=='TS_original' or data_type=='linucb_original' or data_type=='unif_original' or data_type=='optdes_original':
                
                context_states = context_states_all[i, j]
                context_actions = context_actions_all[i, j]
                context_next_states = context_next_states_all[i, j]
                context_rewards = context_rewards_all[i, j]
                context_opt_actions = context_opt_actions_all[i, j]
                context_sum_rewards = context_sum_rewards_all[i, j]
                context_pred_rewards = context_pred_rewards_all[i, j]

                
                if pred_reward_type == "new_arms_linear" or pred_reward_type == "new_arms_non_linear":
                    arms = arms_total[i]
                
                for k in range(n_samples):
                    query_state = np.array([1])
                    optimal_action = env.opt_a
                    traj = {
                        'query_state': query_state,
                        'optimal_action': optimal_action,
                        'context_states': context_states,
                        'context_actions': context_actions,
                        'context_next_states': context_next_states,
                        'context_rewards': context_rewards,
                        'context_opt_actions': context_opt_actions,
                        'context_sum_rewards': context_sum_rewards,
                        'context_pred_rewards': context_pred_rewards,
                        'means': env.means,
                        'arms': arms,
                        'theta': env.theta,
                        'var': env.var,
                    }
                    trajs.append(traj)
                    # print("context_actions", np.shape(context_actions))

            elif data_type=='TS_pred_reward' or data_type=='linucb_pred_reward' or data_type=='unif_pred_reward' or data_type=='optdes_pred_reward':
                
                context_states = context_states_all[i, j]
                context_actions = context_actions_all[i, j]
                context_next_states = context_next_states_all[i, j]
                context_rewards = context_rewards_all[i, j]
                context_opt_actions = context_opt_actions_all[i, j]
                context_sum_rewards = context_sum_rewards_all[i, j]
                context_pred_rewards = context_pred_rewards_all[i, j]

                if pred_reward_type == "new_arms_linear" or pred_reward_type == "new_arms_non_linear":
                    arms = arms_total[i]

                for k in range(n_samples):
                    query_state = np.array([1])
                    optimal_action = env.opt_a
                    traj = {
                        'query_state': query_state,
                        'optimal_action': optimal_action,
                        'context_states': context_states,
                        'context_actions': context_actions,
                        'context_next_states': context_next_states,
                        'context_rewards': context_rewards,
                        'context_opt_actions': context_opt_actions,
                        'context_sum_rewards': context_sum_rewards,
                        'context_pred_rewards': context_pred_rewards,
                        'means': env.means,
                        'arms': arms,
                        'theta': env.theta,
                        'var': env.var,
                    }
                    trajs.append(traj)

            elif data_type=='TS_pred_reward_opt_a' or data_type=='linucb_pred_reward_opt_a' or data_type=='unif_pred_reward_opt_a' or data_type=='optdes_pred_reward_opt_a':
                
                context_states = context_states_all[i, j]
                context_actions = context_actions_all[i, j]
                context_next_states = context_next_states_all[i, j]
                context_rewards = context_rewards_all[i, j]
                context_opt_actions = context_opt_actions_all[i, j]
                context_sum_rewards = context_sum_rewards_all[i, j]
                context_pred_rewards = context_pred_rewards_all[i, j]
                context_pred_opt_a = context_pred_opt_a_all[i, j]

                if pred_reward_type == "new_arms_linear" or pred_reward_type == "new_arms_non_linear":
                    arms = arms_total[i]

                for k in range(n_samples):
                    query_state = np.array([1])
                    optimal_action = env.opt_a
                    traj = {
                        'query_state': query_state,
                        'optimal_action': optimal_action,
                        'context_states': context_states,
                        'context_actions': context_actions,
                        'context_next_states': context_next_states,
                        'context_rewards': context_rewards,
                        'context_opt_actions': context_opt_actions,
                        'context_sum_rewards': context_sum_rewards,
                        'context_pred_rewards': context_pred_rewards,
                        'context_pred_opt_a': context_pred_opt_a,
                        'means': env.means,
                        'arms': arms,
                        'theta': env.theta,
                        'var': env.var,
                    }
                    trajs.append(traj)
            
                
            # elif data_type=='unif':
            #     context_states = context_states_all[i, j]
            #     context_actions = context_actions_all[i, j]
            #     context_next_states = context_next_states_all[i, j]
            #     context_rewards = context_rewards_all[i, j]
            else:
                raise ValueError("Invalid data type")



            # for k in range(n_samples):
            #     query_state = np.array([1])
            #     optimal_action = env.opt_a
            #     traj = {
            #         'query_state': query_state,
            #         'optimal_action': optimal_action,
            #         'context_states': context_states,
            #         'context_actions': context_actions,
            #         'context_next_states': context_next_states,
            #         'context_rewards': context_rewards,
            #         'context_opt_actions': context_opt_actions,
            #         'context_sum_rewards': context_sum_rewards,
            #         'means': env.means,
            #         'arms': arms,
            #         'theta': env.theta,
            #         'var': env.var,
            #     }
            #     trajs.append(traj)
    return trajs


######################################

def rollin_mdp(env, rollin_type):
    """Roll out ``env.horizon`` transitions in ``env``.

    Args:
        env: environment exposing ``reset``, ``transit`` and ``horizon``, plus
            ``sample_state``/``sample_action`` (uniform) or ``opt_action``
            (expert).
        rollin_type: ``'uniform'`` resamples a fresh state and a random action
            at every step, so consecutive transitions are not chained;
            ``'expert'`` follows ``env.opt_action`` starting from the reset state.

    Returns:
        ``(states, actions, next_states, rewards)`` as numpy arrays stacked along
        a leading axis of length ``env.horizon``.

    Raises:
        NotImplementedError: for any other ``rollin_type``.
    """
    transitions = []
    current = env.reset()  # called for both roll-in types
    for _step in range(env.horizon):
        if rollin_type == 'expert':
            chosen = env.opt_action(current)
        elif rollin_type == 'uniform':
            # State is drawn before the action, matching the env's RNG usage.
            current = env.sample_state()
            chosen = env.sample_action()
        else:
            raise NotImplementedError
        successor, reward = env.transit(current, chosen)
        transitions.append((current, chosen, successor, reward))
        current = successor

    # Transpose the list of 4-tuples into four per-field sequences.
    columns = zip(*transitions) if transitions else ([], [], [], [])
    states, actions, next_states, rewards = (np.array(col) for col in columns)
    return states, actions, next_states, rewards


def rand_pos_and_dir(env):
    """Draw a random agent pose using numpy's global RNG.

    Returns:
        ``(pos, dir)``: ``pos`` is a length-3 array with entries drawn from
        ``U[0, env.size)`` and its middle component set to 0.0; ``dir`` is a
        scalar angle drawn from ``U[0, 2*pi)``.
    """
    # Draw order (position first, then heading) matters for seeded runs.
    position = np.random.uniform(low=0, high=env.size, size=3)
    heading = np.random.uniform(low=0, high=2 * np.pi)
    # NOTE(review): index 1 is presumably the vertical axis — confirm.
    position[1] = 0.0
    return position, heading


def rollin_mdp_miniworld(env, horizon, rollin_type, target_shape=(25, 25, 3)):
    """Collect ``horizon`` steps of observations, headings, actions and rewards.

    Args:
        env: MiniWorld-style environment (uses ``place_agent``, ``render_obs``,
            ``step``, ``action_space.n``, ``agent.pos``, ``agent.dir_vec`` and,
            for expert roll-ins, ``opt_a``).
        horizon: number of steps to collect.
        rollin_type: ``'uniform'`` teleports the agent to a random pose and
            takes a uniformly random action each step; ``'expert'`` keeps the
            current pose and takes ``env.opt_a``.
        target_shape: shape each rendered observation is resized to.

    Returns:
        ``(observations, states, actions, rewards)``. ``states`` keeps only the
        direction columns (entries 2: of each stored pos/dir vector), and
        ``actions`` are one-hot rows of length ``env.action_space.n``.

    Raises:
        ValueError: for an unknown ``rollin_type``. The check happens after the
            first observation has already been rendered.
    """
    frames, pose_vectors, one_hot_actions, step_rewards = [], [], [], []
    teleport = rollin_type == 'uniform'

    for _ in range(horizon):
        if teleport:
            start_pos, start_dir = rand_pos_and_dir(env)
            env.place_agent(pos=start_pos, dir=start_dir)

        frame = resize(env.render_obs(), target_shape, anti_aliasing=True)
        frames.append(frame)
        # (x, z) position followed by (x, z) direction components.
        pose_vectors.append(np.concatenate(
            [env.agent.pos[[0, -1]], env.agent.dir_vec[[0, -1]]]))

        if teleport:
            act = np.random.randint(env.action_space.n)
        elif rollin_type == 'expert':
            act = env.opt_a(frame, env.agent.pos, env.agent.dir_vec)
        else:
            raise ValueError("Invalid rollin type")
        _, reward, _, _, _ = env.step(act)

        encoded = np.zeros(env.action_space.n)
        encoded[act] = 1
        one_hot_actions.append(encoded)
        step_rewards.append(reward)

    headings = np.array(pose_vectors)[..., 2:]  # only use dir, not pos
    return np.array(frames), headings, np.array(one_hot_actions), np.array(step_rewards)

def generate_mdp_histories_from_envs(envs, n_hists, n_samples, rollin_type):
    """Build query/context trajectories from roll-ins in each environment.

    For each env, ``n_hists`` context histories are produced with
    ``rollin_mdp``. Each history is paired with ``n_samples`` freshly sampled
    query states, labelled with ``env.opt_action``.

    Returns:
        A list of ``len(envs) * n_hists * n_samples`` dicts with keys
        ``query_state``, ``optimal_action``, ``context_states``,
        ``context_actions``, ``context_next_states``, ``context_rewards``,
        ``goal`` and, if the env defines it, ``perm_index``. Dicts built from
        the same history share the same context array objects.
    """
    trajs = []
    for env in tqdm(envs, desc="Processing environments", unit="env"):
        for _hist in range(n_hists):
            states, actions, next_states, rewards = rollin_mdp(env, rollin_type=rollin_type)
            for _sample in range(n_samples):
                query = env.sample_state()
                record = {
                    'query_state': query,
                    'optimal_action': env.opt_action(query),
                    'context_states': states,
                    'context_actions': actions,
                    'context_next_states': next_states,
                    'context_rewards': rewards,
                    'goal': env.goal,
                }
                # Permuted envs (e.g. DarkroomEnvPermuted) carry a permutation index.
                if hasattr(env, 'perm_index'):
                    record['perm_index'] = env.perm_index
                trajs.append(record)
    return trajs


def generate_mdp_histories_from_envs_custom(envs, n_hists, n_samples, rollin_type):
    """Entry point used by the ``*_custom`` darkroom generators.

    It produces exactly the same trajectories as
    ``generate_mdp_histories_from_envs`` and delegates to it, so the two code
    paths cannot silently diverge.

    Args:
        envs: environments to roll in.
        n_hists: context histories per environment.
        n_samples: query samples per history.
        rollin_type: forwarded to ``rollin_mdp``.

    Returns:
        The list of trajectory dicts built by ``generate_mdp_histories_from_envs``.
    """
    return generate_mdp_histories_from_envs(
        envs,
        n_hists=n_hists,
        n_samples=n_samples,
        rollin_type=rollin_type,
    )



def generate_darkroom_histories(goals, dim, horizon, **kwargs):
    """Create one ``DarkroomEnv`` per goal and roll in histories from each.

    ``kwargs`` (``n_hists``, ``n_samples``, ``rollin_type``) are forwarded to
    ``generate_mdp_histories_from_envs``.
    """
    rooms = [darkroom_env.DarkroomEnv(dim, g, horizon) for g in goals]
    return generate_mdp_histories_from_envs(rooms, **kwargs)

def generate_darkroom_permuted_histories(indices, dim, horizon, **kwargs):
    """Create one ``DarkroomEnvPermuted`` per permutation index and roll in histories.

    ``kwargs`` (``n_hists``, ``n_samples``, ``rollin_type``) are forwarded to
    ``generate_mdp_histories_from_envs``.
    """
    permuted_rooms = [
        darkroom_env.DarkroomEnvPermuted(dim, perm_idx, horizon)
        for perm_idx in indices
    ]
    return generate_mdp_histories_from_envs(permuted_rooms, **kwargs)


def generate_darkroom_histories_custom(goals, dim, horizon, **kwargs):
    """Generate darkroom trajectories augmented with look-ahead targets.

    Trajectories come from ``generate_mdp_histories_from_envs_custom``, with
    one ``DarkroomEnv`` per goal. Each trajectory dict is then extended with:

    * ``context_opt_actions``: row ``h`` holds the context action taken at step
      ``h + 1``; the last row repeats the final context action.
    * ``context_sum_rewards``: undiscounted reward-to-go
      ``sum(context_rewards[h:horizon])``, stored as float32. The last entry
      equals the final reward.
    * ``context_pred_rewards``: an independent copy of ``context_sum_rewards``.

    Args:
        goals: iterable of goals; one environment is built per goal.
        dim: darkroom size; ``context_opt_actions`` rows have ``dim + 1`` columns.
        horizon: episode length; the first ``horizon`` context steps are used.
        **kwargs: forwarded (``n_hists``, ``n_samples``, ``rollin_type``).

    Returns:
        The list of augmented trajectory dicts.
    """
    envs = [darkroom_env.DarkroomEnv(dim, goal, horizon) for goal in goals]
    trajs = generate_mdp_histories_from_envs_custom(envs, **kwargs)

    for traj in trajs:
        context_actions = traj['context_actions']
        rewards = np.asarray(traj['context_rewards'])[:horizon]

        # "Next action" target: each step points at the action taken one step
        # later; the final step has no successor and repeats its own action.
        # NOTE(review): width dim + 1 must match the action encoding produced
        # by the env's roll-in — confirm against darkroom_env.
        context_opt_actions = np.zeros((horizon, dim + 1))
        for h in range(horizon - 1):
            context_opt_actions[h] = context_actions[h + 1]
        if horizon > 0:
            context_opt_actions[horizon - 1] = context_actions[horizon - 1]

        # Undiscounted reward-to-go via a reversed cumulative sum (O(H)).
        # The final step's reward-to-go is its own reward, not 0.
        context_sum_rewards = np.zeros(horizon, dtype=np.float32)
        context_sum_rewards[:] = np.cumsum(rewards[::-1])[::-1]

        traj['context_opt_actions'] = context_opt_actions
        traj['context_sum_rewards'] = context_sum_rewards
        # Separate array so later in-place edits to one field don't affect the other.
        traj['context_pred_rewards'] = context_sum_rewards.copy()

    return trajs

def generate_darkroom_permuted_histories_custom(indices, dim, horizon, **kwargs):
    """Create one ``DarkroomEnvPermuted`` per index and roll in histories.

    ``kwargs`` (``n_hists``, ``n_samples``, ``rollin_type``) are forwarded to
    ``generate_mdp_histories_from_envs_custom``. Unlike
    ``generate_darkroom_histories_custom``, no ``context_opt_actions`` /
    ``context_sum_rewards`` / ``context_pred_rewards`` fields are added.
    NOTE(review): confirm that omission is intended.
    """
    permuted_rooms = [
        darkroom_env.DarkroomEnvPermuted(dim, perm_idx, horizon)
        for perm_idx in indices
    ]
    return generate_mdp_histories_from_envs_custom(permuted_rooms, **kwargs)

def generate_miniworld_histories(env_ids, image_dir, n_hists, n_samples, horizon, target_shape, rollin_type='uniform'):
    """Roll in MiniWorld histories for each task id and build query samples.

    Context observations for history ``j`` of task ``i`` are saved to
    ``{image_dir}/context{i}_{j}.npy``. The trajectory dicts store that path
    rather than the images themselves.

    Args:
        env_ids: task ids passed to ``env.set_task``.
        image_dir: directory for the saved context images (created if missing).
        n_hists: histories rolled in per task.
        n_samples: query samples drawn per history.
        horizon: steps per history.
        target_shape: shape observations are resized to.
        rollin_type: forwarded to ``rollin_mdp_miniworld``.

    Returns:
        A list of ``len(env_ids) * n_hists * n_samples`` trajectory dicts.
    """
    os.makedirs(image_dir, exist_ok=True)

    n_envs = len(env_ids)
    env = gym.make('MiniWorld-OneRoomS6FastMultiFourBoxesFixedInit-v0')

    trajs = []
    try:
        env.reset()
        for i, env_id in enumerate(env_ids):
            # 1-based counter, consistent with the other progress messages in this file.
            print(f"Generating histories for env {i + 1}/{n_envs}")
            env.set_task(env_id)
            env.reset()
            for j in range(n_hists):
                (
                    context_images,
                    context_states,
                    context_actions,
                    context_rewards,
                ) = rollin_mdp_miniworld(env, horizon, rollin_type=rollin_type, target_shape=target_shape)
                filepath = f'{image_dir}/context{i}_{j}.npy'
                np.save(filepath, context_images)

                for _ in range(n_samples):
                    init_pos, init_dir = rand_pos_and_dir(env)
                    env.place_agent(pos=init_pos, dir=init_dir)
                    obs = env.render_obs()
                    obs = resize(obs, target_shape, anti_aliasing=True)

                    action = env.opt_a(obs, env.agent.pos, env.agent.dir_vec)
                    one_hot_action = np.zeros(env.action_space.n)
                    one_hot_action[action] = 1

                    traj = {
                        'query_image': obs,
                        'query_state': env.agent.dir_vec[[0, -1]],  # only use dir, not pos
                        'optimal_action': one_hot_action,
                        'context_images': filepath,
                        'context_states': context_states,
                        'context_actions': context_actions,
                        'context_next_states': context_states,  # unused
                        'context_rewards': context_rewards,
                        'env_id': env_id,  # not used during training, only used for evaling in correct env
                    }
                    trajs.append(traj)
    finally:
        # The env is local to this function and never returned; release it
        # even if generation fails part-way.
        env.close()
    return trajs


####################################



if __name__ == '__main__':
    np.random.seed(0)
    random.seed(0)

    parser = argparse.ArgumentParser()
    common_args.add_dataset_args(parser)
    args = vars(parser.parse_args())
    print("Args: ", args)

    env = args['env']
    n_envs = args['envs']
    n_eval_envs = args['envs_eval']
    n_hists = args['hists']
    n_samples = args['samples']
    horizon = args['H']
    dim = args['dim']
    var = args['var']
    cov = args['cov']
    env_id_start = args['env_id_start']
    env_id_end = args['env_id_end']
    lin_d = args['lin_d']
    data_type = args['data_type']


    n_train_envs = int(.8 * n_envs)
    n_test_envs = n_envs - n_train_envs


    ######### For new arms setting #########
    new_arms = args['new_arms']

    config = {
        'n_hists': n_hists,
        'n_samples': n_samples,
        'horizon': horizon,
    }

    if env == 'bandit':
        # config.update({'dim': dim, 'var': var, 'cov': cov, 'type': 'uniform'})
        config.update({'dim': dim, 'var': var, 'cov': cov})

        train_trajs = generate_bandit_histories(n_train_envs, **config)
        test_trajs = generate_bandit_histories(n_test_envs, **config)
        eval_trajs = generate_bandit_histories(n_eval_envs, **config)

        train_filepath = build_bandit_data_filename(env, n_envs, config, mode=0)
        test_filepath = build_bandit_data_filename(env, n_envs, config, mode=1)
        eval_filepath = build_bandit_data_filename(env, n_eval_envs, config, mode=2)

    elif env == 'linear_bandit':
        # config.update({'dim': dim, 'lin_d': lin_d, 'var': var, 'cov': cov, 'data_type': 'thompson'})
        # config.update({'dim': dim, 'lin_d': lin_d, 'var': var, 'cov': cov, 'data_type': 'unif'})
        config.update({'dim': dim, 'lin_d': lin_d, 'var': var, 'cov': cov, 'data_type': data_type})

        train_trajs = generate_linear_bandit_histories(n_train_envs, 
                        exclude_arm = False,
                        **config)
        test_trajs = generate_linear_bandit_histories(n_test_envs, 
                        exclude_arm = False,
                        **config)
        eval_trajs = generate_linear_bandit_histories(n_eval_envs, 
                        exclude_arm = False,
                        **config)

        train_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=0)
        test_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=1)
        eval_filepath = build_linear_bandit_data_filename(env, n_eval_envs, config, mode=2)

    elif env == 'linear_bandit_exclude':
        # config.update({'dim': dim, 'lin_d': lin_d, 'var': var, 'cov': cov, 'data_type': 'thompson'})
        # config.update({'dim': dim, 'lin_d': lin_d, 'var': var, 'cov': cov, 'data_type': 'unif'})
        config.update({'dim': dim, 'lin_d': lin_d, 'var': var, 'cov': cov, 'data_type': data_type})

        train_trajs = generate_linear_bandit_histories(n_train_envs, 
                        exclude_arm = True,
                        **config)
        test_trajs = generate_linear_bandit_histories(n_test_envs, 
                        exclude_arm = False,
                        include_arm = True,
                        **config)
        eval_trajs = generate_linear_bandit_histories(n_eval_envs, 
                        exclude_arm = False,
                        include_arm = True,
                        **config)

        train_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=0)
        test_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=1)
        eval_filepath = build_linear_bandit_data_filename(env, n_eval_envs, config, mode=2)

    elif env == 'linear_bandit_new_train':

        config.update({'dim': dim, 'lin_d': lin_d, 'var': var, 'cov': cov, 'data_type': data_type})

        train_trajs = generate_linear_bandit_histories(n_train_envs, 
                        exclude_arm = False,
                        **config)
        test_trajs = generate_linear_bandit_histories(n_test_envs, 
                        exclude_arm = False,
                        **config)
        eval_trajs = generate_linear_bandit_histories(n_eval_envs, 
                        exclude_arm = False,
                        **config)
        
        train_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=0)
        test_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=1)
        eval_filepath = build_linear_bandit_data_filename(env, n_eval_envs, config, mode=2)
    
    elif env == 'linear_bandit_train_lookahead':

        pred_reward_type = args['pred_reward_type']

        config.update({'dim': dim, 'lin_d': lin_d, 'var': var, 'cov': cov, 'data_type': data_type, 'pred_reward_type': pred_reward_type})
        if pred_reward_type == 'new_arms_linear' or pred_reward_type == 'new_arms_non_linear':
           config.update({'new_arms': new_arms})

        if pred_reward_type == "latent" or pred_reward_type == "bilinear":
            config.update({'rank': args['rank']}) 

        train_trajs = generate_linear_bandit_histories(n_train_envs, 
                        exclude_arm = False,
                        train_type = "lookahead",
                        **config)
        test_trajs = generate_linear_bandit_histories(n_test_envs, 
                        exclude_arm = False,
                        train_type = "lookahead",
                        **config)
        eval_trajs = generate_linear_bandit_histories(n_eval_envs, 
                        exclude_arm = False,
                        train_type = "lookahead",
                        **config)
        
        train_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=0)
        test_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=1)
        eval_filepath = build_linear_bandit_data_filename(env, n_eval_envs, config, mode=2)
    
    elif env == 'linear_bandit_train_lookahead_wt':

        config.update({'dim': dim, 'lin_d': lin_d, 'var': var, 'cov': cov, 'data_type': data_type})
        if pred_reward_type == 'new_arms_linear' or pred_reward_type == 'new_arms_non_linear':
           config.update({'new_arms': new_arms})

        if pred_reward_type == "latent" or pred_reward_type == "bilinear":
            config.update({'rank': args['rank']}) 

        train_trajs = generate_linear_bandit_histories_wt(n_train_envs, 
                        exclude_arm = False,
                        train_type = "lookahead_wt",
                        **config)
        test_trajs = generate_linear_bandit_histories_wt(n_test_envs, 
                        exclude_arm = False,
                        train_type = "lookahead_wt",
                        **config)
        eval_trajs = generate_linear_bandit_histories_wt(n_eval_envs, 
                        exclude_arm = False,
                        train_type = "lookahead_wt",
                        **config)
        
        train_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=0)
        test_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=1)
        eval_filepath = build_linear_bandit_data_filename(env, n_eval_envs, config, mode=2)
    
    elif env == 'linear_bandit_train_original':

        pred_reward_type = args['pred_reward_type']

        config.update({'dim': dim, 'lin_d': lin_d, 'var': var, 'cov': cov, 'data_type': data_type, 'pred_reward_type': pred_reward_type})
        if pred_reward_type == 'new_arms_linear' or pred_reward_type == 'new_arms_non_linear':
           config.update({'new_arms': new_arms})

        if pred_reward_type == "latent" or pred_reward_type == "bilinear":
            config.update({'rank': args['rank']}) 

        train_trajs = generate_linear_bandit_histories_wt(n_train_envs, 
                        exclude_arm = False,
                        train_type = "train_original",
                        **config)
        test_trajs = generate_linear_bandit_histories_wt(n_test_envs, 
                        exclude_arm = False,
                        train_type = "train_original",
                        **config)
        eval_trajs = generate_linear_bandit_histories_wt(n_eval_envs, 
                        exclude_arm = False,
                        train_type = "train_original",
                        **config)
        
        train_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=0)
        test_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=1)
        eval_filepath = build_linear_bandit_data_filename(env, n_eval_envs, config, mode=2)
    
    elif env == 'linear_bandit_train_original_emp_opt':

        pred_reward_type = args['pred_reward_type']

        config.update({'dim': dim, 'lin_d': lin_d, 'var': var, 'cov': cov, 'data_type': data_type, 'pred_reward_type': pred_reward_type})
        if pred_reward_type == 'new_arms_linear' or pred_reward_type == 'new_arms_non_linear':
           config.update({'new_arms': new_arms})

        if pred_reward_type == "latent" or pred_reward_type == "bilinear":
            config.update({'rank': args['rank']}) 

        train_trajs = generate_linear_bandit_histories_wt(n_train_envs, 
                        exclude_arm = False,
                        train_type = "train_original_emp_opt",
                        **config)
        test_trajs = generate_linear_bandit_histories_wt(n_test_envs, 
                        exclude_arm = False,
                        train_type = "train_original_emp_opt",
                        **config)
        eval_trajs = generate_linear_bandit_histories_wt(n_eval_envs, 
                        exclude_arm = False,
                        train_type = "train_original_emp_opt",
                        **config)
        
        train_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=0)
        test_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=1)
        eval_filepath = build_linear_bandit_data_filename(env, n_eval_envs, config, mode=2)
    
    
    elif env == 'linear_bandit_train_lookahead_pred_reward':
        
        # print("starting env", env)
        pred_reward_type = args['pred_reward_type']

        config.update({'dim': dim, 'lin_d': lin_d, 'var': var, 'cov': cov, 'data_type': data_type, 'pred_reward_type': pred_reward_type})
        if pred_reward_type == 'new_arms_linear' or pred_reward_type == 'new_arms_non_linear':
           config.update({'new_arms': new_arms})
        
        if pred_reward_type == "latent" or pred_reward_type == "bilinear":
            config.update({'rank': args['rank']}) 

        train_trajs = generate_linear_bandit_histories_wt(n_train_envs, 
                        exclude_arm = False,
                        train_type = "lookahead_pred_reward",
                        **config)
        test_trajs = generate_linear_bandit_histories_wt(n_test_envs, 
                        exclude_arm = False,
                        train_type = "lookahead_pred_reward",
                        **config)
        eval_trajs = generate_linear_bandit_histories_wt(n_eval_envs, 
                        exclude_arm = False,
                        train_type = "lookahead_pred_reward",
                        **config)
        
        train_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=0)
        test_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=1)
        eval_filepath = build_linear_bandit_data_filename(env, n_eval_envs, config, mode=2)
    
    elif env == 'linear_bandit_train_lookahead_pred_reward_opt_a':

        pred_reward_type = args['pred_reward_type']
        
        config.update({'dim': dim, 'lin_d': lin_d, 'var': var, 'cov': cov, 'data_type': data_type, 'pred_reward_type': pred_reward_type})
        if pred_reward_type == 'new_arms_linear' or pred_reward_type == 'new_arms_non_linear':
           config.update({'new_arms': new_arms})
        if pred_reward_type == "latent" or pred_reward_type == "bilinear":
            config.update({'rank': args['rank']}) 
        train_trajs = generate_linear_bandit_histories_wt(n_train_envs, 
                        exclude_arm = False,
                        train_type = "lookahead_pred_reward_opt_a",
                        **config)
        test_trajs = generate_linear_bandit_histories_wt(n_test_envs, 
                        exclude_arm = False,
                        train_type = "lookahead_pred_reward_opt_a",
                        **config)
        eval_trajs = generate_linear_bandit_histories_wt(n_eval_envs, 
                        exclude_arm = False,
                        train_type = "lookahead_pred_reward_opt_a",
                        **config)
        
        train_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=0)
        test_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=1)
        eval_filepath = build_linear_bandit_data_filename(env, n_eval_envs, config, mode=2)
    
    elif env == 'linear_bandit_train_AD':

        pred_reward_type = args['pred_reward_type']

        ## linear bandit with AD requires train_type as lookahead_pred_reward

        config.update({'dim': dim, 'lin_d': lin_d, 'var': var, 'cov': cov, 'data_type': data_type, 'pred_reward_type': pred_reward_type})
        if pred_reward_type == 'new_arms_linear' or pred_reward_type == 'new_arms_non_linear':
           config.update({'new_arms': new_arms})
        if pred_reward_type == "latent" or pred_reward_type == "bilinear":
            config.update({'rank': args['rank']}) 
        train_trajs = generate_linear_bandit_histories_wt(n_train_envs, 
                        exclude_arm = False,
                        train_type = "lookahead_pred_reward",
                        **config)
        test_trajs = generate_linear_bandit_histories_wt(n_test_envs, 
                        exclude_arm = False,
                        train_type = "lookahead_pred_reward",
                        **config)
        eval_trajs = generate_linear_bandit_histories_wt(n_eval_envs, 
                        exclude_arm = False,
                        train_type = "lookahead_pred_reward",
                        **config)
        
        train_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=0)
        test_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=1)
        eval_filepath = build_linear_bandit_data_filename(env, n_eval_envs, config, mode=2)

    elif env == 'linear_bandit_train_lookahead_mix':

        config.update({'dim': dim, 'lin_d': lin_d, 'var': var, 'cov': cov, 'data_type': data_type})
        if pred_reward_type == 'new_arms_linear' or pred_reward_type == 'new_arms_non_linear':
           config.update({'new_arms': new_arms})

        if pred_reward_type == "latent" or pred_reward_type == "bilinear":
            config.update({'rank': args['rank']}) 

        train_trajs = generate_linear_bandit_histories(n_train_envs, 
                        exclude_arm = False,
                        train_type = "lookahead_mix",
                        **config)
        test_trajs = generate_linear_bandit_histories(n_test_envs, 
                        exclude_arm = False,
                        train_type = "lookahead",
                        **config)
        eval_trajs = generate_linear_bandit_histories(n_eval_envs, 
                        exclude_arm = False,
                        train_type = "lookahead",
                        **config)
        
        train_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=0)
        test_filepath = build_linear_bandit_data_filename(env, n_envs, config, mode=1)
        eval_filepath = build_linear_bandit_data_filename(env, n_eval_envs, config, mode=2)
    

    elif env == 'darkroom_heldout':

        
        config.update({'dim': dim, 'rollin_type': 'uniform'})
        goals = np.array([[(j, i) for i in range(dim)]
                         for j in range(dim)]).reshape(-1, 2)
        np.random.RandomState(seed=0).shuffle(goals)
        train_test_split = int(.8 * len(goals))
        train_goals = goals[:train_test_split]
        test_goals = goals[train_test_split:]

        eval_goals = np.array(test_goals.tolist() *
                              int(100 // len(test_goals)))
        train_goals = np.repeat(train_goals, n_envs // (dim * dim), axis=0)
        test_goals = np.repeat(test_goals, n_envs // (dim * dim), axis=0)

        train_trajs = generate_darkroom_histories(train_goals, **config)
        test_trajs = generate_darkroom_histories(test_goals, **config)
        eval_trajs = generate_darkroom_histories(eval_goals, **config)

        train_filepath = build_darkroom_data_filename(
            env, n_envs, config, mode=0)
        test_filepath = build_darkroom_data_filename(
            env, n_envs, config, mode=1)
        eval_filepath = build_darkroom_data_filename(env, 100, config, mode=2)

    elif env == 'darkroom_heldout_lookahead_pred_reward':

        
        config.update({'dim': dim, 'rollin_type': 'uniform'})
        # config.update({'n_envs': n_envs, 'n_hists': n_hists, 'n_samples': n_samples, 'horizon': horizon, 'dim': dim, 'var': var, 'cov': cov, 'data_type': data_type})
        goals = np.array([[(j, i) for i in range(dim)]
                         for j in range(dim)]).reshape(-1, 2)
        np.random.RandomState(seed=0).shuffle(goals)
        train_test_split = int(.8 * len(goals))
        train_goals = goals[:train_test_split]
        test_goals = goals[train_test_split:]

        eval_goals = np.array(test_goals.tolist() *
                              int(100 // len(test_goals)))
        train_goals = np.repeat(train_goals, n_envs // (dim * dim), axis=0)
        test_goals = np.repeat(test_goals, n_envs // (dim * dim), axis=0)

        train_trajs = generate_darkroom_histories_custom(train_goals, **config)
        test_trajs = generate_darkroom_histories_custom(test_goals, **config)
        eval_trajs = generate_darkroom_histories_custom(eval_goals, **config)

        train_filepath = build_darkroom_data_filename(
            env, n_envs, config, mode=0)
        test_filepath = build_darkroom_data_filename(
            env, n_envs, config, mode=1)
        eval_filepath = build_darkroom_data_filename(env, 100, config, mode=2)


    elif env == 'miniworld':
        import gymnasium as gym
        import miniworld

        config.update({'rollin_type': 'uniform', 
            'target_shape': (25, 25, 3),
        })

        if env_id_start < 0 or env_id_end < 0:
            env_id_start = 0
            env_id_end = n_envs

        # make sure you don't just generate the same data when batching data collection
        np.random.seed(0 + env_id_start)    
        random.seed(0 + env_id_start)       


        env_ids = np.arange(env_id_start, env_id_end)

        train_test_split = int(.8 * len(env_ids))
        train_env_ids = env_ids[:train_test_split]
        test_env_ids = env_ids[train_test_split:]

        train_filepath = build_miniworld_data_filename(
            env, env_id_start, env_id_end, config, mode=0)
        test_filepath = build_miniworld_data_filename(
            env, env_id_start, env_id_end, config, mode=1)
        eval_filepath = build_miniworld_data_filename(env, 0, 100, config, mode=2)


        train_trajs = generate_miniworld_histories(
            train_env_ids,
            train_filepath.split('.')[0],
            **config)
        test_trajs = generate_miniworld_histories(
            test_env_ids,
            test_filepath.split('.')[0],
            **config)
        eval_trajs = generate_miniworld_histories(
            test_env_ids[:100],
            eval_filepath.split('.')[0],
            **config)


    else:
        raise NotImplementedError

    dir_path = "/media/external/subho/DPT/"

    if not os.path.exists(dir_path + 'datasets'):
        os.makedirs(dir_path + 'datasets', exist_ok=True)
    with open(train_filepath, 'wb') as file:
        pickle.dump(train_trajs, file)
    with open(test_filepath, 'wb') as file:
        pickle.dump(test_trajs, file)
    with open(eval_filepath, 'wb') as file:
        pickle.dump(eval_trajs, file)

    print(f"Saved to {train_filepath}.")
    print(f"Saved to {test_filepath}.")
    print(f"Saved to {eval_filepath}.")
